{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4420e357-69be-4fdd-b21c-0508afaae8f2",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Fine-tuning a face mask detection model with Faster R-CNN"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73ac4244-575c-4f4a-bbf0-bb940993f130",
   "metadata": {
    "tags": []
   },
   "source": [
    "This tutorial fine-tunes a pre-trained Faster R-CNN model from PyTorch to create a face mask detection model that detects if a person is wearing a face mask correctly, not wearing a mask, or wearing it incorrectly. This example demonstrates how to:\n",
    "* Use a dataset from Kaggle, with 853 annotated images in Pascal VOC format.\n",
    "* Parse the Pascal VOC XML annotations with Ray Data.\n",
    "* Retrieve images from S3 and attach them to the dataset.\n",
    "* Set up a distributed training loop using Ray Train.\n",
    "* Run inference and visualize detection results.\n",
    "* Save the final trained model for later use.\n",
    "\n",
    "This approach leverages transfer learning for efficient object detection and scales out distributed training using Ray on Anyscale.\n",
    "\n",
    "Here is the overview of the pipeline:\n",
    "\n",
    "<img\n",
    "  src=\"https://face-masks-data.s3.us-east-2.amazonaws.com/tutorial-diagrams/train_object_detection.png\"\n",
    "  alt=\"Object Detection Training Pipeline\"\n",
    "  style=\"width:75%;\"\n",
    "/>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9b48d29-f4ae-4781-b82d-314ce44e2e3a",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-block alert-warning\">\n",
    "  <b>Anyscale-specific configuration</b>\n",
    "  \n",
    "  <p>Note: This tutorial is optimized for the Anyscale platform. Running on open source Ray, requires additional configuration. For example, you need to manually:</p>\n",
    "  \n",
    "  <ul>\n",
    "    <li>\n",
    "      <b>Configure a Ray cluster:</b> Set up your multi-node environment, including head and worker nodes, and manage resource allocation, like autoscaling and GPU/CPU assignments, without the Anyscale automation. See <a href=\"https://docs.ray.io/en/latest/cluster/getting-started.html\">Ray Clusters</a> for details.\n",
    "    </li>\n",
    "    <li>\n",
    "      <b>Manage dependencies:</b> Install and manage dependencies on each node because you won’t have Anyscale’s Docker-based dependency management. See <a href=\"https://docs.ray.io/en/latest/ray-core/handling-dependencies.html\">Environment Dependencies</a> for instructions on installing and updating Ray in your environment.\n",
    "    </li>\n",
    "    <li>\n",
    "      <b>Set up storage:</b> Configure your own distributed or shared storage system instead of relying on Anyscale’s integrated cluster storage. See <a href=\"https://docs.ray.io/en/latest/train/user-guides/persistent-storage.html\">Configuring Persistent Storage</a> for suggestions on setting up shared storage solutions.\n",
    "    </li>\n",
    "  </ul>\n",
    "\n",
    "</div>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "244a3d03-a8e4-41b7-8eaf-9de8274b9fb7",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Set up dependencies\n",
    "\n",
    "Before proceeding, install the necessary dependencies. You have two options.\n",
    "\n",
    "### Option 1: Build a Docker image\n",
    "\n",
    "To set up an environment on Anyscale, you need to build a Docker image with the required dependencies. See the Anyscale docs for dependency management: https://docs.anyscale.com/configuration/dependency-management/dependency-byod/\n",
    "\n",
    "This workspace includes the `Dockerfile`. Feel free to build the image yourself on Anyscale. \n",
    "\n",
    "Using the Docker image may improve the workspace spin up time and worker node load time. \n",
    "\n",
    "**Note:** For open source Ray, use `rayproject/ray:2.41.0-py312-cu123` as the base image.\n",
    "\n",
    "\n",
    "### Option 2: Install libraries directly\n",
    "\n",
    "Alternatively, you can manually install the required libraries by following this guide:\n",
    "https://docs.anyscale.com/configuration/dependency-management/dependency-development\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "708a1667-d1c1-44fc-aea4-2011b73b9b68",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Set up compute resources\n",
    "\n",
    "To set up the compute resources for the project:\n",
    "* Configure the workspace, or head, node with sufficient CPU and memory for task scheduling and coordination, for example, 8 CPUs and 16 GB of memory.\n",
    "* Avoid assigning a GPU to the workspace node, because it doesn't handle training or need GPU resources.\n",
    "* Add worker nodes by specifying both CPU-based and GPU-based instances:\n",
    "    - CPU nodes, for example, 8 CPUs and 16 GB, to handle general processing tasks, set autoscaling from 0 to 10.\n",
    "    - GPU nodes, for example, 1×T4 with 4 CPUs and 16 GB, to accelerate machine learning and deep learning workloads, set autoscaling from 0 to 10.\n",
    "* Employ this hybrid setup to optimize cost and performance by dynamically allocating tasks to the most appropriate resources.\n",
    "\n",
    "### Benefits of using Anyscale\n",
    "* Worker nodes automatically shut down when no training or inference tasks are running, eliminating idle resource costs.\n",
    "* Leverage autoscaling to dynamically allocate tasks to CPU or GPU nodes based on workload demands.\n",
    "* Minimize infrastructure waste by ensuring that GPU resources are only active when required for ML workloads.\n",
    "* Reduce costs by leveraging `Spot instances` for training with massive data. Anyscale also allow fallback to on-demand instances when spot instances aren't available.\n",
    "\n",
    "For more details on setting up compute configs, see: https://docs.anyscale.com/configuration/compute-configuration/\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d1050cc-3be7-4f43-9879-4594e4c56644",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Kaggle data on AWS S3 \n",
    "\n",
    "Anyscale uploaded the Kaggle mask dataset to a publicly available AWS S3 bucket. The original dataset is from Kaggle: https://www.kaggle.com/datasets/andrewmvd/face-mask-detection\n",
    "\n",
    "The dataset is structured into three main folders: `train`, `test`, and `all`:\n",
    "* `all/`:  Contains 853 samples.\n",
    "* `train/` : Contains 682 samples.\n",
    "* `test/`: Contains 171 samples.\n",
    "\n",
    "Each folder contains two subfolders:\n",
    "\n",
    "* `annotations/`: Contains the Pascal VOC XML annotation files. These files include bounding box information and class labels for each image.\n",
    "* `images/`: Contains the actual image files corresponding to the annotations.\n",
    "\n",
    "This structure helps in efficiently managing and processing the data, whether you're training or evaluating your model. The `all` folder typically aggregates all available images and annotations for ease of access."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76a30e9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "## Note: Ray train v2 will be available on public Ray very soon, but in the meantime we use this workaround\n",
    "## This will be removed once train v2 is pushed\n",
    "import ray\n",
    "ray.shutdown()\n",
    "ray.init(\n",
    "    runtime_env={\n",
    "        \"env_vars\": {\n",
    "            \"RAY_TRAIN_V2_ENABLED\": \"1\",\n",
    "        },\n",
    "    },\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c7c3ff7",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "## Note: Ray train v2 will be available on public Ray very soon, but in the meantime we use this workaround\n",
    "## This will be removed once train v2 is pushed\n",
    "\n",
    "echo \"RAY_TRAIN_V2_ENABLED=1\" > .env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b25b396e",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Note: Ray train v2 will be available on public Ray very soon, but in the meantime we use this workaround\n",
    "## This will be removed once train v2 is pushed\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "load_dotenv()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8be2643b-106b-40fe-9b9d-7a3c5c1f95f2",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Inspect an example image\n",
    "\n",
    "Start by fetching and displaying an example image from the S3 storage."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ade7af1c-f596-42ab-9a67-52b8c9f24b54",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import io\n",
    "\n",
    "from PIL import Image\n",
    "import requests\n",
    "\n",
    "response = requests.get(\"https://face-masks-data.s3.us-east-2.amazonaws.com/all/images/maksssksksss0.png\")\n",
    "image = Image.open(io.BytesIO(response.content))\n",
    "image"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6633e2e-6d3b-40c1-beff-c92063092b5f",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Inspect an annotation file in Pascal VOC format\n",
    "\n",
    "PASCAL VOC is a widely recognized annotation format for object detection, storing bounding boxes, object classes, and image metadata in XML files. Its structured design and common adoption by popular detection frameworks make it a standard choice for many computer vision tasks. For more details, see: http://host.robots.ox.ac.uk/pascal/VOC/\n",
    "\n",
    "View the annotation for the preceding image, which is stored in Pascal VOC XML format. \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e26efb6e-0a24-43ee-aa16-0e599a6bee1d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!curl \"https://face-masks-data.s3.us-east-2.amazonaws.com/all/annotations/maksssksksss0.xml\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1ef737b",
   "metadata": {},
   "source": [
    "\n",
    "Observe some key fields:\n",
    "\n",
    "\n",
    "* The `<size>` contains details about the image dimensions (width, height) and color depth. For instance, the following block indicates that the image is 512 pixels wide, 366 pixels tall, and has 3 color channels, such as RGB. \n",
    "\n",
    "```xml\n",
    "        <size>\n",
    "          <width>512</width>\n",
    "          <height>366</height>\n",
    "          <depth>3</depth>\n",
    "        </size>\n",
    "```\n",
    "\n",
    "\n",
    "* Each `<object>` block describes one annotated object in the image. `<name>` is the label for that object. In this dataset, it can be `with_mask`, `without_mask`, or `mask_weared_incorrect`:\n",
    "\n",
    "* Each `<object>` contains a `<bndbox>` tag, which specifies the coordinates of the bounding box, the rectangle that tightly encloses the object.\n",
    "\n",
    "  - `<xmin>` and `<ymin>` are the top-left corner of the bounding box.\n",
    "  - `<xmax>` and `<ymax>` are the bottom-right corner of the bounding box.\n"
   ]
  },
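  {
   "cell_type": "markdown",
   "id": "a7c2e9f0",
   "metadata": {},
   "source": [
    "As a minimal sketch of how these fields map to Python values, the following snippet parses a small hand-written annotation with the standard library's `xml.etree.ElementTree`. The XML string, its coordinates, and the `read_voc_fields` helper are hypothetical and only illustrate the structure. The tutorial itself parses the real files with `xmltodict` in a later step.\n",
    "\n",
    "```python\n",
    "import xml.etree.ElementTree as ET\n",
    "\n",
    "# Hypothetical annotation, shaped like the Pascal VOC files in this dataset.\n",
    "SAMPLE_XML = \"\"\"\n",
    "<annotation>\n",
    "  <filename>example.png</filename>\n",
    "  <size><width>512</width><height>366</height><depth>3</depth></size>\n",
    "  <object>\n",
    "    <name>with_mask</name>\n",
    "    <bndbox><xmin>79</xmin><ymin>105</ymin><xmax>109</xmax><ymax>142</ymax></bndbox>\n",
    "  </object>\n",
    "</annotation>\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "def read_voc_fields(xml_str: str) -> dict:\n",
    "    # Return the image size and a list of (label, box) pairs.\n",
    "    root = ET.fromstring(xml_str.strip())\n",
    "    size = root.find(\"size\")\n",
    "    objects = []\n",
    "    for obj in root.findall(\"object\"):\n",
    "        box = obj.find(\"bndbox\")\n",
    "        # Read the corners in (xmin, ymin, xmax, ymax) order.\n",
    "        coords = tuple(int(box.find(tag).text) for tag in (\"xmin\", \"ymin\", \"xmax\", \"ymax\"))\n",
    "        objects.append((obj.find(\"name\").text, coords))\n",
    "    return {\n",
    "        \"width\": int(size.find(\"width\").text),\n",
    "        \"height\": int(size.find(\"height\").text),\n",
    "        \"objects\": objects,\n",
    "    }\n",
    "\n",
    "\n",
    "fields = read_voc_fields(SAMPLE_XML)\n",
    "print(fields)\n",
    "```\n",
    "\n",
    "Each box is a tuple of pixel coordinates in `(xmin, ymin, xmax, ymax)` order: the top-left corner followed by the bottom-right corner.\n"
   ]
  },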
  {
   "cell_type": "markdown",
   "id": "7436d1ff-5223-4711-ae77-bbba5f17077a",
   "metadata": {},
   "source": [
    "### Parse Pascal VOC annotations\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6481c3cf",
   "metadata": {},
   "source": [
    "The annotation files are in XML format; however, since Ray data lacks an XML parser, read the binary files directly from S3 using `ray.data.read_binary_files`.\n",
    "\n",
    "Then, use `parse_voc_annotation` function to extract and parse XML annotation data from a binary input stored in the `bytes` field of a dataset record. It then processes the XML structure to extract bounding box coordinates, object labels, and the filename, returning them as NumPy arrays for further use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d3090ba",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from typing import List, Tuple\n",
    "import xmltodict\n",
    "import numpy as np\n",
    "import ray.data\n",
    "import boto3\n",
    "\n",
    "# # Create a Ray Dataset from the S3 uri.\n",
    "annotation_s3_uri = \"s3://face-masks-data/train/annotations/\"\n",
    "ds = ray.data.read_binary_files(annotation_s3_uri)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b9bde61",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "CLASS_TO_LABEL = {\n",
    "    \"background\": 0,\n",
    "    \"with_mask\": 1,\n",
    "    \"without_mask\": 2,\n",
    "    \"mask_weared_incorrect\": 3\n",
    "}\n",
    "\n",
    "\n",
    "def parse_voc_annotation(record) -> dict:\n",
    "    xml_str = record[\"bytes\"].decode(\"utf-8\")\n",
    "    if not xml_str.strip():\n",
    "        raise ValueError(\"Empty XML string\")\n",
    "        \n",
    "    annotation = xmltodict.parse(xml_str)[\"annotation\"]\n",
    "\n",
    "    # Normalize the object field to a list.\n",
    "    objects = annotation[\"object\"]\n",
    "    if isinstance(objects, dict):\n",
    "        objects = [objects]\n",
    "\n",
    "    boxes: List[Tuple] = []\n",
    "    for obj in objects:\n",
    "        x1 = float(obj[\"bndbox\"][\"xmin\"])\n",
    "        y1 = float(obj[\"bndbox\"][\"ymin\"])\n",
    "        x2 = float(obj[\"bndbox\"][\"xmax\"])\n",
    "        y2 = float(obj[\"bndbox\"][\"ymax\"])\n",
    "        boxes.append((x1, y1, x2, y2))\n",
    "\n",
    "    labels: List[int] = [CLASS_TO_LABEL[obj[\"name\"]] for obj in objects]\n",
    "    filename = annotation[\"filename\"]\n",
    "\n",
    "    return {\n",
    "        \"boxes\": np.array(boxes),\n",
    "        \"labels\": np.array(labels),\n",
    "        \"filename\": filename\n",
    "    }\n",
    "\n",
    "\n",
    "annotations = ds.map(parse_voc_annotation)\n",
    "annotations.take(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f654ac02-3d46-4762-ba36-279c2e686336",
   "metadata": {},
   "source": [
    "### Batch image retrieval from S3\n",
    "Next, fetch images from an S3 URL based on the filenames present in the batch dictionary. For each filename, check if the file has an appropriate image extension, construct the S3 URL, and then download and convert the image to an RGB NumPy array. After that, append all the loaded images into a new key \"image\" within the batch dictionary. \n",
    "\n",
    "Note that in Ray Data, the `map_batches` method only passes the batch of data to your function, meaning you can’t directly supply additional parameters like `images_s3_url`. To work around this, use `partial` to pre-bind the `images_s3_url` argument to your `read_images` function. The `read_images` function then takes just the batch because that’s all `map_batches` provides, and uses the bound URL internally to fetch images from the S3 bucket. \n",
    "\n",
    "Note that you can use either a `function` or a `callable class` to perform the `map` or `map_batches` transformation:\n",
    "* For **functions**, Ray Data uses stateless **Ray tasks**, which are ideal for simple tasks that don’t require loading heavyweight models.\n",
    "* For **classes**, Ray Data uses stateful **Ray actors**, making them well-suited for more complex tasks that involve loading heavyweight models.\n",
    "\n",
    "For more information, see : https://docs.ray.io/en/latest/data/api/doc/ray.data.Dataset.map.html and https://docs.ray.io/en/latest/data/api/doc/ray.data.Dataset.map_batches.html"
   ]
  },
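  {
   "cell_type": "markdown",
   "id": "b4d81c37",
   "metadata": {},
   "source": [
    "To make the pre-binding step concrete, here is a minimal sketch in plain Python with no Ray dependency. The `add_prefix` function, the `example.com` URL, and the sample batch are hypothetical stand-ins for `read_images`, the S3 URL, and a real Ray Data batch.\n",
    "\n",
    "```python\n",
    "from functools import partial\n",
    "from typing import Dict, List\n",
    "\n",
    "\n",
    "def add_prefix(base_url: str, batch: Dict[str, List[str]]) -> Dict[str, List[str]]:\n",
    "    # Hypothetical batch function that needs one extra argument, base_url.\n",
    "    batch[\"url\"] = [base_url + name for name in batch[\"filename\"]]\n",
    "    return batch\n",
    "\n",
    "\n",
    "# Pre-bind the first positional argument. The result accepts only the batch.\n",
    "add_train_prefix = partial(add_prefix, \"https://example.com/train/images/\")\n",
    "\n",
    "result = add_train_prefix({\"filename\": [\"a.png\", \"b.png\"]})\n",
    "print(result[\"url\"])\n",
    "```\n",
    "\n",
    "Because `partial` fixes the leading positional argument, the extra parameter must come before `batch` in the function signature, which is why `read_images` takes `images_s3_url` first.\n"
   ]
  },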
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "446f1ef8-831e-40c6-aaa3-ba48a0a1cf2b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from typing import Dict\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "from functools import partial\n",
    "\n",
    "\n",
    "def read_images(images_s3_url:str, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:\n",
    "    images: List[np.ndarray] = []\n",
    "    \n",
    "    for filename in batch[\"filename\"]:\n",
    "        \n",
    "        if not filename.lower().endswith((\".png\", \".jpg\", \".jpeg\", \".bmp\", \".gif\")):\n",
    "            continue\n",
    "            \n",
    "        url = os.path.join(images_s3_url, filename)\n",
    "        response = requests.get(url)\n",
    "        image = Image.open(io.BytesIO(response.content)).convert(\"RGB\")  # Ensure image is in RGB.\n",
    "\n",
    "        images.append(np.array(image))\n",
    "    batch[\"image\"] = np.array(images, dtype=object)\n",
    "    return batch\n",
    "\n",
    "\n",
    "# URL for training images stored in S3.\n",
    "train_images_s3_url = \"https://face-masks-data.s3.us-east-2.amazonaws.com/train/images/\"\n",
    "\n",
    "# Bind the URL to your image reading function.\n",
    "train_read_images = partial(read_images, train_images_s3_url)\n",
    "\n",
    "# Map the image retrieval function over your annotations dataset.\n",
    "train_dataset = annotations.map_batches(train_read_images)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee69659b-003c-4a2b-b75d-9ac014a8e97a",
   "metadata": {},
   "source": [
    "### Set up Ray Train for distributed fine-tuning / training\n",
    "\n",
    "This section configures and runs a distributed training loop using Ray Train. The training function handles several essential steps:\n",
    "\n",
    "* **Defining the model**: Initializes a Faster R-CNN model.\n",
    "* **Configuring the optimizer and scheduler**: Sets up the optimizer and learning rate scheduler for training.\n",
    "* **Running the training loop**: Iterates over epochs and batches to update model parameters.\n",
    "* **Checkpointing**: Saves checkpoints, but only on the primary (rank 0) worker to avoid redundant writes.\n",
    "\n",
    "#### Distributed training with Ray Train\n",
    "\n",
    "When launching a distributed training job, each worker executes this training function `train_func`.\n",
    "\n",
    "  - **Without Ray Train**: You would train on a single machine or manually configure PyTorch’s `DistributedDataParallel` to handle data splitting, gradient synchronization, and communication among workers. This setup requires significant manual coordination.\n",
    "\n",
    "  - **With Ray Train:**. Ray Train automatically manages parallelism. It launches multiple training processes (actors), each handling its own shard of the dataset. Under the hood, Ray synchronizes gradients among workers and provides features for checkpointing, metrics reporting, and more. The parallelism primarily occurs at the batch-processing step, with each worker handling a different portion of the data.\n",
    "\n",
    "To learn more about Ray train, see: https://docs.ray.io/en/latest/train/overview.html\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d9985c3-99e1-433d-b271-83a1e449ff49",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "import os\n",
    "import torch\n",
    "from torchvision import models\n",
    "from tempfile import TemporaryDirectory\n",
    "\n",
    "import ray\n",
    "from ray import train\n",
    "\n",
    "from torchvision import transforms \n",
    "import tempfile\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "\n",
    "def train_func(config):\n",
    "    # Get device\n",
    "    device = ray.train.torch.get_device()\n",
    "\n",
    "    # Define model\n",
    "    model = models.detection.fasterrcnn_resnet50_fpn(num_classes=len(CLASS_TO_LABEL))\n",
    "    model = ray.train.torch.prepare_model(model)\n",
    "    \n",
    "    # Define optimizer\n",
    "    parameters = [p for p in model.parameters() if p.requires_grad]\n",
    "    optimizer = torch.optim.SGD(\n",
    "        parameters,\n",
    "        lr=config[\"lr\"],\n",
    "        momentum=config[\"momentum\"],\n",
    "        weight_decay=config[\"weight_decay\"],\n",
    "    )\n",
    "\n",
    "    # Define learning rate scheduler\n",
    "    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(\n",
    "        optimizer, milestones=config[\"lr_steps\"], gamma=config[\"lr_gamma\"]\n",
    "    )\n",
    "\n",
    "\n",
    "    for epoch in range(config[\"epochs\"]):\n",
    "        model.train()\n",
    "\n",
    "        # Warmup learning rate scheduler for first epoch\n",
    "        if epoch == 0:\n",
    "            warmup_factor = 1.0 / 1000\n",
    "            lr_scheduler = torch.optim.lr_scheduler.LinearLR(\n",
    "                optimizer, start_factor=warmup_factor, total_iters=250\n",
    "            )\n",
    "        \n",
    "        # Retrieve the training dataset shard for the current worker.\n",
    "        train_dataset_shard = train.get_dataset_shard(\"train\")\n",
    "        batch_iter = train_dataset_shard.iter_batches(batch_size=config[\"batch_size\"])\n",
    "        batch_iter = tqdm(batch_iter, desc=f\"Epoch {epoch+1}/{config['epochs']}\", unit=\"batch\")\n",
    "\n",
    "\n",
    "        for batch_idx, batch in enumerate(batch_iter):\n",
    "            inputs = [transforms.ToTensor()(image).to(device) for image in batch[\"image\"]]\n",
    "            targets = [\n",
    "                {\n",
    "                    \"boxes\": torch.as_tensor(boxes).to(device),\n",
    "                    \"labels\": torch.as_tensor(labels).to(device),\n",
    "                }\n",
    "                for boxes, labels in zip(batch[\"boxes\"], batch[\"labels\"])\n",
    "            ]\n",
    "            \n",
    "            # Forward pass through the model.\n",
    "            loss_dict = model(inputs, targets)\n",
    "            losses = sum(loss for loss in loss_dict.values())\n",
    "            \n",
    "             # Backpropagation.\n",
    "            optimizer.zero_grad()\n",
    "            losses.backward()\n",
    "            optimizer.step()\n",
    "            \n",
    "            # Step the learning rate scheduler.\n",
    "            if lr_scheduler is not None:\n",
    "                lr_scheduler.step()\n",
    "            \n",
    "            # Report metrics.\n",
    "            current_worker = ray.train.get_context().get_world_rank()\n",
    "            metrics = {\n",
    "                \"losses\": losses.item(),\n",
    "                \"epoch\": epoch,\n",
    "                \"lr\": optimizer.param_groups[0][\"lr\"],\n",
    "                **{key: value.item() for key, value in loss_dict.items()},\n",
    "            }\n",
    "\n",
    "            # Print batch metrics.\n",
    "            print(f\"Worker {current_worker} - Batch {batch_idx}: {metrics}\")\n",
    "           \n",
    "\n",
    "\n",
    "        if lr_scheduler is not None:\n",
    "            lr_scheduler.step()\n",
    "\n",
    "        # Save a checkpoint on the primary worker for each epoch.\n",
    "        if ray.train.get_context().get_world_rank() == 0:\n",
    "            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:\n",
    "                torch.save(\n",
    "                    model.module.state_dict(), os.path.join(temp_checkpoint_dir, \"model.pt\")\n",
    "                )\n",
    "                checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)\n",
    "                train.report(metrics, checkpoint=checkpoint)\n",
    "        else: # Save metrics from all workers for each epoch.\n",
    "            train.report(metrics)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32f13db3",
   "metadata": {},
   "source": [
    "#### How train.get_dataset_shard(\"train\") works\n",
    "\n",
    "A shard is a partition of the overall dataset allocated to a specific worker. For example, if you have 4 workers and 10,000 images, each worker receives 2,500 images, that is, one shard of 2,500 each.\n",
    "\n",
    "Ray Train automatically splits your dataset into shards across multiple workers. Calling `train.get_dataset_shard(\"train\")` returns the subset (shard) of the dataset for the current worker. Each worker trains on a different shard in parallel. This approach contrasts with a typical single-machine PyTorch setup, where you might rely on PyTorch’s DataLoader or a DistributedSampler for data distribution. For more details: https://docs.ray.io/en/latest/train/api/doc/ray.train.get_dataset_shard.html\n",
    "\n",
    "\n",
    "#### Batch size\n",
    "\n",
    "The batch size specifies how many samples each worker processes in a single forward/backward pass. For instance, a batch size of 4 means each training step processes 4 samples within that worker’s shard before performing a gradient update. In practice, you should carefully select the batch size based on the model size and GPU memory size. \n",
    "\n",
    "#### Checkpointing on the primary (rank 0) worker\n",
    "\n",
    "In this example, all workers maintain the same model parameters. They're kept in sync during updates. Therefore, by the end of each epoch, or at checkpoint time, every worker’s model state is identical. Saving checkpoints from only the primary worker (rank 0) prevents redundant or conflicting writes and ensures one clear, consistent checkpoint.\n",
    "\n",
    "To learn more about saving and loading checkpoints, see:https://docs.ray.io/en/latest/train/user-guides/checkpoints.html\n",
    "\n",
    "#### Reporting metrics for all worker nodes\n",
    "\n",
    "Use `train.report` to track metrics from **all worker nodes**. Ray Train’s internal bookkeeping records these metrics, enabling you to monitor progress and analyze results after training completes. \n",
    "\n",
    "**Note: You receive errors if you only report the metrics from the primary worker, a common mistake to avoid.** "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2360f2b5-f66c-47c4-9a7f-9705bb497699",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Launch the fine-tuning / training process with TorchTrainer\n",
    "\n",
    "Configure and initiate training using TorchTrainer from Ray Train. Be patient, as this process may take some time.\n",
    "\n",
    "**For demonstration purposes, set `epochs` to 2, but the performance of the fine-tuned model won't be optimal.** In practice, you would typically train for 20-30 epochs to achieve a well fine-tuned model.\n",
    "\n",
    "The `num_workers` parameter specifies how many parallel worker processes that Ray starts for data-parallel training. Set `num_workers=2` for demonstration purposes, but in real scenarios, the setting depends on:\n",
    "\n",
    "* Your max number of available GPUs: Ray can assign each worker to one GPU, if use_gpu=True. Hence, if you have 4 GPUs, you could set num_workers=4.\n",
    "* Desired training speed: More workers can lead to faster training because Ray Train splits the workload among multiple devices or processes. If your training data is large and you have the computational resources, you can increase `num_workers` to accelerate training.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8536cff2-915e-40fa-b22d-37d4e84563a1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "from ray.train.torch import TorchTrainer\n",
    "\n",
    "\n",
    "storage_path = \"/mnt/cluster_storage/face-mask-experiments_v1/\"\n",
    "run_config = ray.train.RunConfig(storage_path=storage_path, name=\"face-mask-experiments_v1\")\n",
    "\n",
    "trainer = TorchTrainer(\n",
    "    train_func,\n",
    "    train_loop_config={\n",
    "        \"batch_size\": 4, # ajust it based on your GPU memory, a batch size that is too large could cause OOM issue\n",
    "        \"lr\": 0.02,\n",
    "        \"epochs\": 2,  # You'd normally train for 20-30 epochs to get a good performance.\n",
    "        \"momentum\": 0.9,\n",
    "        \"weight_decay\": 1e-4,\n",
    "        \"lr_steps\": [16, 22],\n",
    "        \"lr_gamma\": 0.1,\n",
    "    },\n",
    "    scaling_config = ray.train.ScalingConfig(num_workers=2, use_gpu=True),\n",
    "    run_config = run_config,\n",
    "    datasets={\"train\": train_dataset},\n",
    ")\n",
    "\n",
    "results = trainer.fit()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "004da509-6c51-4be0-b7b6-e7b0ae5981ab",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Inspect results when training completes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14dc494c-fff4-4a22-bc99-3611ed710a29",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import os\n",
    "\n",
    "\n",
    "print(\"Metrics reported during training:\")\n",
    "print(results.metrics)\n",
    "\n",
    "print(\"\\nLatest checkpoint reported during training:\")\n",
    "print(results.checkpoint)\n",
    "\n",
    "print(\"\\nPath where logs are stored:\")\n",
    "print(results.path)\n",
    "\n",
    "print(\"\\nException raised, if training failed:\")\n",
    "print(results.error)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d9c98df-95f8-48e5-9027-b32fa6594771",
   "metadata": {},
   "source": [
    "### Run inference and visualize predictions on a test image\n",
    "After training, run the model on a single test image for a sanity check:\n",
    "\n",
    "* Download an image from a URL.\n",
    "* Run the model for predictions.\n",
    "* Visualize the detections (bounding boxes and labels).\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86801926-1351-4f84-8ca3-2827db3d4828",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import io\n",
    "import requests\n",
    "import numpy as np\n",
    "import torch\n",
    "from PIL import Image, ImageDraw, ImageFont\n",
    "\n",
    "# CLASS_TO_LABEL dictionary\n",
    "CLASS_TO_LABEL = {\n",
    "    \"background\": 0,\n",
    "    \"with_mask\": 1,\n",
    "    \"without_mask\": 2,\n",
    "    \"mask_weared_incorrect\": 3\n",
    "}\n",
    "\n",
    "# Create reverse label mapping\n",
    "LABEL_TO_CLASS = {v: k for k, v in CLASS_TO_LABEL.items()}\n",
    "\n",
    "# Define colors for each category\n",
    "LABEL_COLORS = {\n",
    "    \"with_mask\": \"green\",\n",
    "    \"without_mask\": \"red\",\n",
    "    \"mask_weared_incorrect\": \"yellow\"\n",
    "}\n",
    "\n",
    "def load_image_from_url(url):\n",
    "    \"\"\"\n",
    "    Downloads the image from the given URL and returns it as a NumPy array.\n",
    "    \"\"\"\n",
    "    response = requests.get(url)\n",
    "    response.raise_for_status()  # Raise an error if the download failed.\n",
    "    image = Image.open(io.BytesIO(response.content)).convert('RGB')\n",
    "    return np.array(image)\n",
    "\n",
    "def predict_and_visualize(image_np, model, confidence_threshold=0.5):\n",
    "    \"\"\"Run model prediction on an image array and visualize results.\"\"\"\n",
    "    # Convert numpy array to PIL Image.\n",
    "    image_pil = Image.fromarray(image_np)\n",
    "    draw = ImageDraw.Draw(image_pil)\n",
    "    font = ImageFont.load_default()\n",
    "\n",
    "    # Preprocess image for model.\n",
    "    image_tensor = torch.from_numpy(image_np).permute(2, 0, 1).float() / 255.0\n",
    "\n",
    "    # Make prediction.\n",
    "    with torch.no_grad():\n",
    "        predictions = model([image_tensor])[0]  # Get first (and only) prediction\n",
    "\n",
    "    # Filter predictions by confidence.\n",
    "    keep = predictions['scores'] > confidence_threshold\n",
    "    boxes = predictions['boxes'][keep]\n",
    "    labels = predictions['labels'][keep]\n",
    "    scores = predictions['scores'][keep]\n",
    "\n",
    "    # Draw each detection.\n",
    "    for box, label, score in zip(boxes, labels, scores):\n",
    "        x1, y1, x2, y2 = box.tolist()\n",
    "        \n",
    "        # Convert numeric label back to class name.\n",
    "        class_name = LABEL_TO_CLASS.get(label.item(), \"unknown\")\n",
    "        \n",
    "        # Get corresponding color.\n",
    "        box_color = LABEL_COLORS.get(class_name, \"white\")  # Default to white if unknown.\n",
    "        \n",
    "        # Draw bounding box.\n",
    "        draw.rectangle([x1, y1, x2, y2], outline=box_color, width=2)\n",
    "        \n",
    "        # Prepare text.\n",
    "        text = f\"{class_name} {score:.2f}\"\n",
    "        \n",
    "        # Calculate text size.\n",
    "        text_bbox = draw.textbbox((0, 0), text, font=font)\n",
    "        text_width = text_bbox[2] - text_bbox[0]\n",
    "        text_height = text_bbox[3] - text_bbox[1]\n",
    "        \n",
    "        # Draw text background.\n",
    "        draw.rectangle(\n",
    "            [x1, y1 - text_height - 2, x1 + text_width, y1],\n",
    "            fill=box_color\n",
    "        )\n",
    "        \n",
    "        # Draw text.\n",
    "        draw.text(\n",
    "            (x1, y1 - text_height - 2),\n",
    "            text,\n",
    "            fill=\"black\" if box_color in [\"yellow\"] else \"white\",  # Ensure good contrast\n",
    "            font=font\n",
    "        )\n",
    "\n",
    "    return image_pil\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94254df9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load model.\n",
    "ckpt = results.checkpoint\n",
    "with ckpt.as_directory() as ckpt_dir:\n",
    "    model_path = os.path.join(ckpt_dir, \"model.pt\")\n",
    "    model = models.detection.fasterrcnn_resnet50_fpn(num_classes=len(CLASS_TO_LABEL))\n",
    "    state_dict = torch.load(model_path, map_location=torch.device('cpu'), weights_only=True)\n",
    "    model.load_state_dict(state_dict)\n",
    "    model.eval()\n",
    "\n",
    "# URL for a test image.\n",
    "url = \"https://face-masks-data.s3.us-east-2.amazonaws.com/all/images/maksssksksss0.png\"\n",
    "\n",
    "# Load image from URL.\n",
    "image_np = load_image_from_url(url)\n",
    "\n",
    "# Run prediction and visualization.\n",
    "result_image = predict_and_visualize(image_np, model, confidence_threshold=0.7)\n",
    "result_image.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f8da3b9d",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-block alert-warning\"> <b> Note: You may notice that the results aren't optimal because you trained for only 2 epochs. \n",
    "Typically, training would require around 20 epochs.</b> \n",
    "<div>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "524678d9-6a37-4345-9788-7940568dffee",
   "metadata": {},
   "source": [
    "### Store the trained model locally\n",
    "\n",
    "After training, you can access the checkpoint, load the model weights, and save the model locally in your workspace. This allows you to easily download the model to your local machine, inspect the model, or do a sanity check. **Don't load the model and run batch inference directly from the workspace**, as this forces the Ray cluster to copy the weights to other nodes, significantly slowing down the process. To enable faster batch inference, use Anyscale’s cluster storage to store the model instead.\n",
    "\n",
    "```python\n",
    "ckpt = results.checkpoint\n",
    "with ckpt.as_directory() as ckpt_dir:\n",
    "    model_path = os.path.join(ckpt_dir, \"model.pt\")\n",
    "    model = models.detection.fasterrcnn_resnet50_fpn(num_classes=len(CLASS_TO_LABEL))\n",
    "    state_dict = torch.load(model_path, map_location=torch.device('cpu'), weights_only=True)\n",
    "    model.load_state_dict(state_dict)\n",
    "    model.eval()\n",
    "\n",
    "# Save the model locally.\n",
    "save_path = \"./saved_model/fasterrcnn_model_mask_detection.pth\"  # Choose your path.\n",
    "os.makedirs(os.path.dirname(save_path), exist_ok=True)  # Create directory if needed.\n",
    "torch.save(model.state_dict(), save_path)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd751b10",
   "metadata": {},
   "source": [
    "### Store the model on Anyscale cluster storage\n",
    "You can store your model on Anyscale cluster storage, `/mnt/cluster_storage`, for faster batch inference or serving on Anyscale. If multiple worker nodes need to access the model in a distributed computing environment, storing it in cluster storage ensures all nodes load the model quickly and avoids redundant copies.\n",
    "\n",
    "For more information, see: https://docs.anyscale.com/configuration/storage/\n",
    "\n",
    "\n",
    "```python\n",
    "ckpt = results.checkpoint\n",
    "with ckpt.as_directory() as ckpt_dir:\n",
    "    model_path = os.path.join(ckpt_dir, \"model.pt\")\n",
    "    model = models.detection.fasterrcnn_resnet50_fpn(num_classes=len(CLASS_TO_LABEL))\n",
    "    state_dict = torch.load(model_path, map_location=torch.device('cpu'), weights_only=True)\n",
    "    model.load_state_dict(state_dict)\n",
    "    model.eval()\n",
    "\n",
    "# Save the model locally\n",
    "save_path = \"/mnt/cluster_storage/fasterrcnn_model_mask_detection.pth\"  # Choose your path\n",
    "os.makedirs(os.path.dirname(save_path), exist_ok=True)  # Create directory if needed\n",
    "torch.save(model.state_dict(), save_path)\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "271522db-7bc3-4227-ae99-8344019ffd11",
   "metadata": {},
   "source": [
    "### Store the model in the cloud\n",
    "You can store your model in a cloud such as AWS S3, Google Cloud Storage, or Hugging Face. Store the model remotely on a cloud helps your team collaboration, versioning, and efficient deployment and inference. Later on, you can use `smart-open` to load the model from AWS S3, Google Cloud Storage, or use AutoModel to load the model from Hugging Face. See how to load the model from AWS S3 in the next notebook.\n",
    "\n",
    "This sample code uploads your model to AWS S3. Be sure to install the boto3 library properly configure it with AWS credentials:\n",
    "\n",
    "```python\n",
    "import os\n",
    "import torch\n",
    "import boto3\n",
    "import smart_open\n",
    "from torchvision import models\n",
    "\n",
    "# Define S3 details\n",
    "S3_BUCKET = \"your-s3-bucket-name\"\n",
    "S3_KEY = \"path/in/s3/fasterrcnn_model_mask_detection.pth\"\n",
    "S3_URI = f\"s3://{S3_BUCKET}/{S3_KEY}\"\n",
    "\n",
    "# Load the model checkpoint\n",
    "ckpt = results.checkpoint\n",
    "with ckpt.as_directory() as ckpt_dir:\n",
    "    model_path = os.path.join(ckpt_dir, \"model.pt\")\n",
    "    model = models.detection.fasterrcnn_resnet50_fpn(num_classes=len(CLASS_TO_LABEL))\n",
    "    state_dict = torch.load(model_path, map_location=torch.device('cpu'), weights_only=True)\n",
    "    model.load_state_dict(state_dict)\n",
    "    model.eval()\n",
    "\n",
    "# Upload to S3 directly using smart_open\n",
    "try:\n",
    "    with smart_open.open(S3_URI, \"wb\") as f:\n",
    "        torch.save(model.state_dict(), f)\n",
    "    print(f\"Model successfully uploaded to {S3_URI}\")\n",
    "except Exception as e:\n",
    "    print(f\"Error uploading to S3: {e}\")\n",
    "\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a330db6",
   "metadata": {},
   "source": [
    "## Clean up the cluster storage\n",
    "\n",
    "You can see the files you stored in the cluster storage. You can see that you created `/mnt/cluster_storage/face-mask-experiments_v1/` to store the training artifacts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8d23b9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls -lah /mnt/cluster_storage/"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ebd1cb97",
   "metadata": {},
   "source": [
    "**Remember to clean up the cluster storage by removing it:**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8800c731",
   "metadata": {},
   "outputs": [],
   "source": [
    "!rm -rf /mnt/cluster_storage/face-mask-experiments_v1/"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dfb00a72",
   "metadata": {},
   "source": [
    "## Next steps\n",
    "\n",
    "For the following notebooks, **Anyscale has already uploaded a fine-tuned mask detection model with a batch size of 20, to AWS S3**. The following notebook demonstrates how to download the model to an Anyscale cluster for batch inference, among other tasks.\n",
    "\n",
    "However, feel free to use your own fine-tuned model (around 20 epochs) if you prefer."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
